from mynumpy import *

def sample_each_column(pi):
    """Draw one categorical sample independently from each column of ``pi``.

    Column ``pi[:, j]`` is treated as a probability vector over the row
    indices ``0 .. pi.shape[0]-1``.

    Parameters
    ----------
    pi : 2-D ndarray
        Non-negative entries; every column must sum to 1 (checked with
        ``np.allclose``, so small deviations are tolerated).

    Returns
    -------
    ndarray of int, shape (pi.shape[1],)
        Entry ``j`` is the row index sampled from column ``j``; always in
        ``[0, pi.shape[0])``.
    """
    assert(pi.ndim==2)
    assert(np.allclose(np.sum(pi,axis=0),1.0))
    assert(np.all(pi>=0.0))
    Pi = np.cumsum(pi,axis=0)
    r = rand(pi.shape[1])
    # 1 - Pi[k] is the probability mass strictly after row k. It decreases
    # in k, so counting how many of these tails exceed a uniform draw r
    # selects row m with probability pi[m].
    i = np.sum(1-Pi>r,axis=0)
    # Columns only sum to 1 within allclose tolerance (about 1e-5), and
    # cumsum adds rounding. The last tail 1 - Pi[-1] can therefore be
    # slightly positive, and a tiny r would then count every row. Clamp so
    # that case never returns the out-of-range index pi.shape[0].
    return np.minimum(i, pi.shape[0]-1)

def one_from_each_col(A,which):
    """Pick one entry per column: return ``A[which[j], j]`` for every column ``j``.

    Parameters
    ----------
    A : 2-D ndarray
        Matrix to read from.
    which : integer array-like
        Row index to take from each column. It has length ``A.shape[1]``, or
        is broadcastable to it. Every value must satisfy
        ``0 <= which < A.shape[0]``.

    Returns
    -------
    ndarray
        A new array holding the selected entries, in column order.
    """
    # asarray lets callers pass a plain list of row indices.
    which = np.asarray(which)
    assert(A.ndim==2)
    assert(np.all(which >= 0))
    assert(np.all(which <  A.shape[0]))
    # Direct fancy indexing pairs which[j] with column j. Unlike A.ravel(),
    # it never copies the whole matrix when A is not C-contiguous.
    return A[which, np.arange(A.shape[1])]

def logmean_of_logdata(logX):
    """Return ``log(mean(exp(logX), axis=0))``, computed in log space.

    Dividing by the number of rows (``len(logX)``) becomes subtracting
    ``log(len(logX))`` from every entry. ``logsumexp`` over axis 0 then
    yields the log of the column means.
    """
    n_rows = len(logX)
    log_scaled = logX - log(n_rows)
    return logsumexp(log_scaled, axis=0)

def sigmoid(a):
    """Return the logistic sigmoid ``1 / (1 + exp(-a))``, elementwise.

    For very negative ``a``, ``exp(-a)`` overflows to inf. The result is
    then ``1/inf == 0.0``, which is the correct limit, so the overflow
    warning is suppressed rather than reported.

    The overflow-free form ``(tanh(a/2) + 1) / 2`` is not used: for very
    negative ``a``, tanh rounds to exactly -1 and the tiny true value is
    lost to cancellation.

    NOTE(review): this assumes ``exp`` (from mynumpy) is numpy's exp, so
    ``np.errstate`` controls its overflow reporting. Confirm.
    """
    with np.errstate(over='ignore'):
        return 1/(1+exp(-a))